@yweng0828 yweng0828 commented Oct 22, 2025

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced tree-based speculative decoding with improved resource management and batch-aware draft token processing.
    • Added advanced test case for Eagle3 decoding workflow.
  • Bug Fixes

    • Fixed KV cache alignment by narrowing past key-value length calculations to current request set.
  • Refactor

    • Reworked draft token handling to support static and dynamic tree decoding paths more cleanly.
    • Streamlined spec-decoding configuration and initialization logic.
  • Tests

    • Added comprehensive test coverage for tree-based draft token generation.

Description

This PR implements the runtime logic for the draft token tree. Because capturable drafting loops (CDL) perform better, the implementation is built on CDL. A non-CDL draft PR is available at #8109, but it is not being considered for merging.

With this PR, we now have the following features:

  1. After the target model completes the prefill phase, the draft model will generate draft tokens based on the static tree.
  2. These draft tokens are passed to the target model and forwarded using XQA.
  3. The draft tokens are validated to determine which of them can be accepted.
  4. Draft tokens are then generated for the first generation step.

Unverified:

  1. A new round of validation requires KV cache rewind for the target model (https://github.com/NVIDIA/TensorRT-LLM/pull/8421/files). This will be further tested and updated in a subsequent PR.
  2. trtllm-bench test
  3. disagg test

Required tests for this PR before merging:
[x] Verify that the current implementation is compatible with CUDA Graph
[x] Verify that this PR does not impact other existing functionality

For detailed changes, please refer to the image below:

[Image: draft_tokens_tree_details]

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@yweng0828 yweng0828 requested a review from mikeiovine October 22, 2025 15:12
@yweng0828 yweng0828 force-pushed the yweng/add_draft_token_tree_runtime_on_cdl branch from b0c522d to 040bc9a Compare October 23, 2025 06:11
@yweng0828 yweng0828 marked this pull request as ready for review October 26, 2025 13:40
@yweng0828 yweng0828 requested review from a team as code owners October 26, 2025 13:40
@yweng0828 yweng0828 requested review from hlu1 and removed request for hlu1 October 26, 2025 13:40
@yweng0828
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #22532 [ run ] triggered by Bot. Commit: ef8b2b6

@coderabbitai
Contributor

coderabbitai bot commented Oct 26, 2025

📝 Walkthrough


This pull request introduces comprehensive tree-based speculative decoding enhancements across TensorRT-LLM. The changes refactor attention metadata interfaces, drafter model logic, and spec-decoding parameter handling to support dynamic tree-based token generation with explicit batch-level and resource management, including new buffer allocation strategies, tree-aware position tracking, and revised sampling paths.

Changes

Cohort: Attention Backend Refactoring
Files: tensorrt_llm/_torch/attention_backend/interface.py, tensorrt_llm/_torch/attention_backend/trtllm.py
Summary: Updated update_spec_dec_param signature from is_spec_dec_tree/is_spec_dec_dynamic_tree flags to batch-aware parameters (batch_size, is_spec_decoding_enabled, spec_metadata, spec_tree_manager, max_draft_len, max_total_draft_tokens). Reworked internal logic to decouple spec-dec mode from fixed flags and conditionally populate buffers (position_offsets, packed_mask) based on tree type; introduced parameterized helper methods for generating spec-decoding tensors.

Cohort: PyExecutor Integration
Files: tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/resource_manager.py, tensorrt_llm/_torch/pyexecutor/sampler.py
Summary: Extended PyTorchModelEngine to propagate resource_manager through input-preparation paths; integrated Eagle3SpecMetadata handling and added tracking of request_accepted_path for per-request draft token acceptance. Updated _prepare_tp_inputs signature to accept resource manager. Refined KV cache length slicing in resource_manager and introduced py_num_accepted_draft_tokens_indices tracking in sampler for accepted draft paths; eliminated batch tree-sampling path.

Cohort: Speculative Decoding Core
Files: tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/interface.py, tensorrt_llm/_torch/speculative/spec_tree_manager.py
Summary: Enhanced Eagle3SpecMetadata with request_accepted_path field and revised hidden-state read/write selection logic to use accepted-path indices. Expanded attention_need_spec_dec_mode with four explicit use-case branches. Substantially reworked SpecTreeManager: added persistent buffers (spec_dec_packed_mask, spec_dec_position_offsets, top_k_list_cuda, tokens_gather_idx, etc.); replaced compute_spec_dec_pack_mask with compute_spec_dec_packed_mask; expanded static-tree initialization with index mappings and node lists; introduced drafter-model-specific buffers for masks, offsets, and hidden-state indices.

Cohort: Drafting Logic
Files: tensorrt_llm/_torch/speculative/drafting_loops.py, tensorrt_llm/_torch/speculative/drafter.py, tensorrt_llm/_torch/speculative/model_drafter.py
Summary: Added prepare_for_generation_with_tree_decoding helper function to assemble inputs for tree-based drafting. Updated ChainDrafter.sample signature to accept optional spec_tree_manager and added tree-based sampling branch with per-layer top-k. Modified model_drafter to propagate accepted draft token indices and use max_total_draft_tokens for buffer allocation and accumulation ranges.

Cohort: Tests
Files: tests/integration/defs/test_e2e.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
Summary: Added new e2e test for Eagle3 tree-based decoding. Introduced comprehensive unittest for prepare_for_generation_with_tree_decoding with mock Eagle3 metadata and tree-manager objects, verifying attention and spec metadata state across batch/layer configurations. Refactored tree-sampling test to use ChainDrafter directly instead of TorchSampler pathway, introducing lightweight DummyModel and simplifying control flow.
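
To make the buffer names above more concrete, here is a minimal pure-Python sketch of how an attention mask and per-node position offsets could be derived from a static tree. It assumes the tree is given as root-to-leaf paths of node indices padded with -1, which is the layout the sampler comments in this PR suggest. The function name tree_mask_and_offsets and the example tree are hypothetical, and the real spec_dec_packed_mask / spec_dec_position_offsets buffers are device tensors rather than lists.

```python
from typing import List, Sequence, Tuple


def tree_mask_and_offsets(
    paths: Sequence[Sequence[int]],
) -> Tuple[List[List[int]], List[int]]:
    """Build an ancestor mask and depth offsets from padded root-to-leaf paths.

    mask[i][j] == 1 means node i may attend to node j (j is i or an ancestor).
    offsets[i] is the depth of node i, with the root at depth 0.
    """
    num_nodes = max(max(path) for path in paths) + 1
    mask = [[0] * num_nodes for _ in range(num_nodes)]
    offsets = [0] * num_nodes
    for path in paths:
        nodes = [n for n in path if n >= 0]  # drop the -1 padding
        for depth, node in enumerate(nodes):
            offsets[node] = depth
            for ancestor in nodes[: depth + 1]:
                mask[node][ancestor] = 1
    return mask, offsets


# Hypothetical tree: root 0 has children 1 and 2; node 1 has child 3.
example_paths = [[0, 1, 3], [0, 2, -1]]
example_mask, example_offsets = tree_mask_and_offsets(example_paths)
```

For this example tree, node 3 attends to nodes 0, 1 and 3, and the offsets are the node depths.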

Sequence Diagram(s)

sequenceDiagram
    participant Engine as PyTorchModelEngine
    participant AttMeta as AttentionMetadata
    participant SpecMeta as Eagle3SpecMetadata
    participant TreeMgr as SpecTreeManager
    participant Drafter as ChainDrafter
    
    rect rgb(240, 248, 255)
    Note over Engine,Drafter: Tree-Based Speculative Decoding Flow (New)
    end
    
    Engine->>Engine: _prepare_tp_inputs(resource_manager)
    Engine->>SpecMeta: Set request_accepted_path
    
    Engine->>AttMeta: update_spec_dec_param(batch_size, spec_metadata, spec_tree_manager, ...)
    AttMeta->>TreeMgr: Retrieve tree structure & buffers
    
    alt Static Tree Path
        TreeMgr->>AttMeta: Copy spec_dec_packed_mask, position_offsets
        AttMeta->>AttMeta: Populate kv_lens_cuda, seq_lens
    else Dynamic Tree Path
        AttMeta->>AttMeta: Initialize placeholders for dynamic updates
    end
    
    Engine->>Drafter: forward() with spec_tree_manager
    Drafter->>TreeMgr: get_generation_lengths(), get_masks(), get_offsets()
    Drafter->>Drafter: sample(draft_layer_idx, logits, spec_tree_manager)
    
    alt Tree Sampling
        Drafter->>TreeMgr: Retrieve per-layer top_k_list
        Drafter->>Drafter: Top-k sampling with tree constraints
    else Linear Sampling
        Drafter->>Drafter: Greedy/standard sampling
    end
    
    Drafter-->>Engine: Draft tokens with tree indices
    Engine->>Engine: prepare_for_generation_with_tree_decoding()
    Engine->>AttMeta: Update position_ids, masks, and indices per layer
    Engine->>SpecMeta: Update gather indices and hidden-state read/write offsets

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Areas requiring extra attention:

  • SpecTreeManager buffer management and initialization (spec_tree_manager.py): Extensive additions of static-tree buffers (tokens_gather_idx, logits_gather_idx, drafter-model-specific masks/offsets) require careful validation of shape consistency and initialization order, especially the transition between dynamic and static tree modes.
  • AttentionMetadata parameter flow (trtllm.py): The reworked update_spec_dec_param logic branches on static vs. dynamic tree types and populates multiple buffers conditionally; verify correctness of buffer population and alignment with downstream usage in attention kernels.
  • Eagle3SpecMetadata accepted-path logic (eagle3.py): The new request_accepted_path tracking and revised read/write index computation must be validated against actual request sequences and KV cache addressing to ensure no off-by-one or alignment errors.
  • ChainDrafter tree-based sampling (drafting_loops.py): The new prepare_for_generation_with_tree_decoding and tree-sampling paths introduce new dependencies on spec_tree_manager; ensure correct tensor reshaping, offset application, and layer iteration.
  • PyExecutor resource propagation (model_engine.py): Verify that resource_manager is correctly threaded through all input-preparation and warmup paths, and that Eagle3SpecMetadata instance checks and request_accepted_path assignments are complete and non-breaking.
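
As a rough illustration of the tree-sampling path with per-layer top-k flagged above, the following pure-Python sketch picks the top-k token ids for each node of one draft layer. It is a sketch under stated assumptions, not the ChainDrafter code: the function name sample_tree_layer, the list-based logits, and the idea that each node carries its own k are all hypothetical stand-ins for the tensors and top_k_list kept by spec_tree_manager.

```python
from typing import List, Sequence


def sample_tree_layer(
    logits_per_node: Sequence[Sequence[float]],
    top_k_per_node: Sequence[int],
) -> List[int]:
    """Return the concatenated top-k token ids for every node in one layer.

    logits_per_node[i] holds the vocabulary logits produced at node i and
    top_k_per_node[i] is the number of children that node expands into.
    """
    if len(logits_per_node) != len(top_k_per_node):
        raise ValueError("one top-k value is required per node")
    new_tokens: List[int] = []
    for logits, k in zip(logits_per_node, top_k_per_node):
        # Rank token ids by descending logit; ties keep the lower token id first.
        ranked = sorted(range(len(logits)), key=lambda tok: logits[tok], reverse=True)
        new_tokens.extend(ranked[:k])
    return new_tokens


# Two nodes in the layer: the first expands into 2 children, the second into 1.
layer_tokens = sample_tree_layer([[0.1, 0.9, 0.5], [2.0, 1.0, 3.0]], [2, 1])
```

Because the number of children per node is fixed by the static tree, the output length is known ahead of time, which is what makes fixed-size buffers possible.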

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The pull request description provides a clear explanation of the implemented features and references a detailed image of the changes. The description documents four key features implemented, lists unverified items requiring future work, and includes a completed PR checklist. However, the Test Coverage section is completely empty—the template requires authors to "list clearly what are the relevant test(s) that can safeguard the changes," yet no tests are listed in this mandatory section. While the raw summary shows several new tests were added (test_draft_token_tree_quickstart_advanced_eagle3, test_draft_token_static_tree_prepare_for_generation, etc.), these are not documented in the PR description's Test Coverage section as required by the template.
Docstring Coverage ⚠️ Warning Docstring coverage is 26.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The pull request title "[TRTLLM-8160][feat] Add draft token tree runtime on CDL" directly aligns with the primary objective of the changeset, which is to implement runtime logic for draft token tree using capturable drafting loops. The title is concise, specific, and follows the required format with a valid JIRA ticket ID (TRTLLM-8160) and an appropriate type label (feat). It clearly communicates the main change without vague terminology or excessive detail.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 23

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

773-799: Fix device/dtype mismatch in advanced indexing, convert scalar tensor to int, and ensure tensor-to-list conversion.

Three bugs confirmed:

  1. Indexing dtype mismatch (lines 792–798): eagle_paths is int32; advanced indexing on CPU tensors requires int64 indices.
  2. Tensor-to-bool ambiguity (line 801): cur_accepted_len is a 0-D tensor; convert it to a Python int before the comparison.
  3. Tensor-to-list assignment (lines 825–826): assigning an int32 tensor slice to a list slice stores tensor objects instead of integers; call .tolist() first.

Apply the diff:

-            all_draft_tokens = torch.tensor(request.py_draft_tokens)  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
-                -1
-            )  # [max_total_draft_tokens]
+            # Host-side CPU tensors, ensure long dtype for indexing
+            all_draft_tokens = torch.as_tensor(request.py_draft_tokens, dtype=torch.long, device="cpu")
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(-1).to(dtype=torch.long, device="cpu")  # [max_total_draft_tokens + 1]
@@
-            for path_idx, path in enumerate(eagle_paths):
-                path_exclude_root = (
-                    path[1:] - 1
-                )  # [max_draft_len], '[1:]' since the new_tokens does not contain the root node.
-                # '-1' is the index shift after exclude the root node.
-                draft_tokens_indices = path_exclude_root[path_exclude_root >= 0]  # [max_draft_len]
-                target_tokens_indices = path[path >= 0]  # [max_draft_len + 1]
+            for path_idx, path in enumerate(eagle_paths):
+                # Convert to long for CPU advanced indexing
+                path_long = path.to(dtype=torch.long)
+                path_exclude_root = path_long[1:] - 1  # exclude root; -1 index shift
+                draft_tokens_indices = path_exclude_root[path_exclude_root >= 0]
+                target_tokens_indices = path_long[path_long >= 0]
@@
-                cur_draft_tokens = all_draft_tokens[draft_tokens_indices]
-                cur_target_tokens = all_target_tokens[target_tokens_indices]
+                cur_draft_tokens = all_draft_tokens.index_select(0, draft_tokens_indices)
+                cur_target_tokens = all_target_tokens.index_select(0, target_tokens_indices)
@@
-                cur_accepted_len = torch.cumprod(
-                    (cur_draft_tokens == cur_target_tokens[:-1]).int(), dim=-1
-                ).sum()
-
-                # Accepted one more token from the target model.
-                cur_accepted_len += 1
-
-                if cur_accepted_len > longest_accepted_len:
+                cur_accepted_len = int(torch.cumprod(
+                    (cur_draft_tokens == cur_target_tokens[:-1]).to(torch.int32), dim=-1
+                ).sum().item()) + 1  # +1 accounts for root
+
+                if cur_accepted_len > longest_accepted_len:
                     longest_accepted_len = cur_accepted_len
                     longest_match_path_idx = path_idx
@@
-                request.py_num_accepted_draft_tokens_indices[: num_accepted_draft_tokens - 1] = (
-                    eagle_paths[longest_match_path_idx][1:longest_accepted_len]
-                )  # exclude the root node
+                accepted_indices = eagle_paths[longest_match_path_idx][1:longest_accepted_len].tolist()
+                request.py_num_accepted_draft_tokens_indices[: num_accepted_draft_tokens - 1] = accepted_indices  # exclude root
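
The acceptance logic that this diff touches can be followed more easily in a tensor-free form. The sketch below mirrors the comparison in the diff (each draft token is checked against the target token sampled at its parent node, and one extra target token is always accepted), but it uses plain Python lists. The function name longest_accepted_path and the example values are hypothetical, and it assumes at least one path is supplied.

```python
from typing import List, Sequence, Tuple


def longest_accepted_path(
    paths: Sequence[Sequence[int]],
    draft_tokens: Sequence[int],
    target_tokens: Sequence[int],
) -> Tuple[int, int, List[int]]:
    """Find the tree path with the longest accepted prefix.

    paths: root-to-leaf node indices, root is 0, padded with -1.
    draft_tokens[i]: draft token at node i + 1 (the root is excluded).
    target_tokens[i]: token the target model sampled at node i.
    Returns (accepted_len, path_idx, accepted_node_indices_without_root).
    """
    best_len, best_idx = 0, -1
    for path_idx, path in enumerate(paths):
        nodes = [n for n in path if n >= 0]  # drop the -1 padding
        accepted = 1  # one token from the target model is always accepted
        for parent, child in zip(nodes, nodes[1:]):
            # draft_tokens has no root entry, hence the -1 index shift.
            if draft_tokens[child - 1] != target_tokens[parent]:
                break
            accepted += 1
        if accepted > best_len:
            best_len, best_idx = accepted, path_idx
    accepted_nodes = [n for n in paths[best_idx] if n >= 0][1:best_len]
    return best_len, best_idx, accepted_nodes


# Hypothetical tree: root 0 -> {1, 2}, node 1 -> {3}.
result = longest_accepted_path(
    paths=[[0, 1, 3], [0, 2, -1]],
    draft_tokens=[10, 11, 12],       # tokens at nodes 1, 2, 3
    target_tokens=[10, 12, 99, 7],   # tokens sampled at nodes 0, 1, 2, 3
)
```

Here the first path is fully accepted, so the result is a length of 3 on path 0 with accepted draft nodes 1 and 3. Returning plain integers sidesteps the tensor-to-int and tensor-to-list issues listed above.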
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

1-4: Missing NVIDIA Apache-2.0 header (2025)

Per repo guidelines, prepend the standard header.

Apply at file start:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+

As per coding guidelines

tensorrt_llm/_torch/speculative/drafting_loops.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the required NVIDIA Apache-2.0 header (year 2025).

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the required NVIDIA Apache-2.0 header (year 2025).

tensorrt_llm/_torch/attention_backend/trtllm.py (1)

1-1: Add required NVIDIA Apache-2.0 header (2025).

File is missing the mandatory license header. Please prepend it.

Apply this diff:

+# Copyright (c) 2025, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)

310-313: The dynamic path indexing bug is confirmed. The code attempts 3D indexing ([:, i, :]) on a 2D tensor (eagle_paths[tree_idx] is shape [max_total_draft_tokens + 1, max_draft_len + 1]), which causes a shape mismatch at assignment. The proposed fix correctly reshapes the nonzero indices to 1D and assigns them row-wise to the 2D tensor.

🧹 Nitpick comments (13)
tensorrt_llm/_torch/speculative/drafter.py (1)

67-67: Remove useless expression.

self.max_total_draft_tokens is a no-op here. Drop it to satisfy linters and avoid confusion.

-            self.max_total_draft_tokens
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

584-585: Ensure contiguous int32 slice for KV lengths before passing to C++ op.

Slicing returns a view; make it contiguous and int32 to match extension expectations.

-        past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
+        past_key_value_lengths = (
+            attn_metadata.kv_lens_cuda.narrow(0, 0, len(requests)).to(torch.int32).contiguous()
+        )

Confirm torch.ops.tensorrt_llm.update_kv_cache_draft_token_location expects int32 on the same device as other KV tensors.

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (3)

15-25: Make DummyModel.forward fail fast

Use explicit NotImplementedError to surface unintended calls during refactors.

 class DummyModel(torch.nn.Module):
@@
-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        pass
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("DummyModel.forward should not be called in this unit test")

54-60: Decouple from external model roots to keep unit test hermetic

Avoid requiring llm_models_root for a path that is not used. Consider a benign default like os.environ.get("LLM_MODELS_ROOT", "/tmp") or pass a dummy path.

-        spec_config = EagleDecodingConfig(
+        spec_config = EagleDecodingConfig(
             max_draft_len=max_draft_len,
             max_total_draft_tokens=max_total_draft_tokens,
-            speculative_model_dir=eagle_model_dir,
+            speculative_model_dir=os.environ.get("LLM_MODELS_ROOT", "/tmp"),

61-67: Assertion style and device selection nits

  • Prefer torch.equal(output_tokens, ref_new_tokens) for clarity.
  • Derive ref tensor device from logits.device to avoid hardcoding CUDA.
-        assert torch.all(output_tokens == ref_new_tokens)
+        assert torch.equal(output_tokens, ref_new_tokens)

And when constructing ref_new_tokens:

-    ref_new_tokens = torch.tensor([...], device='cuda')
+    ref_new_tokens = torch.tensor([...], device=logits.device)
tests/integration/defs/test_e2e.py (1)

2060-2093: Refactor eagle_choices string construction for clarity; remove memory-guard suggestion

The --eagle_choices flag is confirmed as supported in quickstart_advanced.py (type=str, default=None). However, refactor the eagle_choices JSON construction using json.dumps() for consistency with existing codebase patterns (e.g., test_e2e.py line 709) and to reduce manual JSON string errors.

The memory-guard suggestion (skipif marker) is unnecessary—the test suite consistently validates memory requirements post-execution via _check_mem_usage(), which is already present and correct in this test (_check_mem_usage(running_log, [27, 0, 0, 0])).
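
A minimal sketch of the json.dumps() approach suggested above, assuming the flag takes a JSON-encoded list of lists. The choice values and the variable names are hypothetical and are not taken from the test in this PR.

```python
import json

# Hypothetical static-tree choices; each inner list is one path of child indices.
eagle_choices = [[0], [1], [0, 0], [0, 1]]

# json.dumps produces a valid JSON string, so no brackets or commas
# have to be assembled by hand.
eagle_choices_arg = json.dumps(eagle_choices)

# The flag name comes from the review comment; the rest of the command is omitted.
extra_args = ["--eagle_choices", eagle_choices_arg]
```

Passing the value as its own list element also avoids shell quoting issues when the command is run without a shell.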

tensorrt_llm/_torch/speculative/eagle3.py (2)

174-178: Ensure paired iterables are same length

Add an assertion before the loop to guarantee request_ids and seq_lens have equal length (useful under Py3.8 where zip(strict=...) is unavailable).

@@
         if not self.is_draft_model:
-            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
+            assert len(self.request_ids) == len(self.seq_lens), \
+                "request_ids and seq_lens must be the same length"
+            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
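
If the same check is needed in several places, it could be wrapped in a small helper. This is only a sketch: the name zip_equal is hypothetical, and it exists because zip(strict=True) was added in Python 3.10 and is therefore unavailable under the Python 3.8 target.

```python
from typing import Iterator, Sequence, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def zip_equal(first: Sequence[A], second: Sequence[B]) -> Iterator[Tuple[A, B]]:
    """Zip two sequences, raising if their lengths differ.

    Unlike a bare assert, the check is not stripped when Python runs with -O.
    """
    if len(first) != len(second):
        raise ValueError(
            f"length mismatch: {len(first)} != {len(second)}"
        )
    return zip(first, second)


# Example with placeholder request ids and sequence lengths.
pairs = list(zip_equal([101, 102], [7, 9]))
```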

197-201: Replace fullwidth parenthesis in comment

Use ASCII ) to avoid lint failures (RUF003).

-            # 2）is_first_draft
+            # 2) is_first_draft
tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (3)

17-17: Avoid mutating sys.path in tests

This path hack is brittle in CI. Prefer relying on the test runner’s import paths.

-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+# Avoid mutating sys.path; rely on test runner configuration.

6-8: Remove unused model path plumbing

llm_models_root()/eagle_model_dir are not needed; pass a dummy string to EagleDecodingConfig to decouple from local assets.

-from utils.llm_data import llm_models_root
@@
-    models_path = llm_models_root()
-    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"  # It will not actually be used.
+    eagle_model_dir = "unused"

Also applies to: 22-24


662-663: Unnecessary unittest entrypoint

This is a pytest-style function test. unittest.main() won’t discover it; safe to drop to avoid confusion.

-if __name__ == "__main__":
-    unittest.main()
+# Intentionally no unittest entrypoint; use pytest discovery.
tensorrt_llm/_torch/speculative/drafting_loops.py (1)

145-147: Typo in comment

“toshift” → “to shift”.

-        1] - 1  # shape: [next_layer_gen_len_per_req]. -1 is toshift the root node
+        1] - 1  # shape: [next_layer_gen_len_per_req]. -1 is to shift the root node
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)

324-364: Optional: simplify packed-mask computation with bitshifts.

Avoid pow on int tensors and repeated reshape rebinds; use bit operations for clarity and speed.

Apply this diff:

-        num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32)
-        int_tensor = mask_matrix.reshape(
-            -1, num_process_tokens
-        )  # shape: [num_trees * num_process_tokens, num_process_tokens]
-        packed_mask = packed_mask.reshape(
-            -1,
-            num_blocks)  # shape: [num_trees * num_process_tokens, num_blocks]
-
-        for block_idx in range(num_blocks):
-            start_idx = block_idx * 32
-            end_idx = min(start_idx + 32, num_process_tokens)
-            if end_idx < start_idx:
-                break
-            block_bits = int_tensor[:, start_idx:end_idx]
-            weight = torch.pow(
-                2,
-                torch.arange(end_idx - start_idx,
-                             dtype=torch.int32,
-                             device=int_tensor.device))
-            block_value = torch.sum(block_bits * weight, dim=-1)
-            packed_mask[:, block_idx] = block_value
-
-        packed_mask = packed_mask.reshape(num_trees, num_process_tokens,
-                                          num_blocks)
+        num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32)
+        rows = mask_matrix.reshape(-1, num_process_tokens)
+        out = packed_mask.reshape(-1, num_blocks)
+        for block_idx in range(num_blocks):
+            start = block_idx * 32
+            end = min(start + 32, num_process_tokens)
+            if end <= start:
+                break
+            span = end - start
+            weights = (torch.ones(span, dtype=torch.int32, device=rows.device) << torch.arange(span, dtype=torch.int32, device=rows.device))
+            out[:, block_idx] = (rows[:, start:end].to(torch.int32) * weights).sum(dim=-1)
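
For readers who want to check the packing scheme by hand, here is a pure-Python sketch of the same idea: each mask row is split into blocks of 32 columns, and column j of a block sets bit j of that block's word. The function name pack_mask_rows is hypothetical. Python integers are unbounded, so unlike an int32 tensor the word stays non-negative when bit 31 is set.

```python
import math
from typing import List, Sequence


def pack_mask_rows(mask: Sequence[Sequence[int]], bits_per_word: int = 32) -> List[List[int]]:
    """Pack each 0/1 mask row into a list of bits_per_word-wide integers."""
    num_tokens = len(mask[0])
    num_blocks = math.ceil(num_tokens / bits_per_word)
    packed: List[List[int]] = []
    for row in mask:
        words = []
        for block_idx in range(num_blocks):
            start = block_idx * bits_per_word
            word = 0
            # Column `start + offset` maps to bit `offset` of this word.
            for offset, bit in enumerate(row[start:start + bits_per_word]):
                if bit:
                    word |= 1 << offset
            words.append(word)
        packed.append(words)
    return packed


# A 3-node tree mask: the root, plus two children that each see the root and themselves.
small_packed = pack_mask_rows([[1, 0, 0], [1, 1, 0], [1, 0, 1]])

# A 33-column row spills into a second word.
wide_row = [0] * 33
wide_row[32] = 1
wide_packed = pack_mask_rows([wide_row])
```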
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2956978 and ef8b2b6.

📒 Files selected for processing (15)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (2 hunks)
  • tensorrt_llm/_torch/attention_backend/interface.py (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (18 hunks)
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
  • tensorrt_llm/_torch/speculative/drafter.py (1 hunks)
  • tensorrt_llm/_torch/speculative/drafting_loops.py (3 hunks)
  • tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/model_drafter.py (3 hunks)
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py (7 hunks)
  • tests/integration/defs/test_e2e.py (1 hunks)
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (1 hunks)
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tests/integration/defs/test_e2e.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tests/integration/defs/test_e2e.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tests/integration/defs/test_e2e.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/speculative/eagle3.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (1)
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/tensorrt_llm/thop/attentionOp.cpp
🧬 Code graph analysis (11)
tensorrt_llm/_torch/speculative/interface.py (1)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
  • TrtllmAttention (1172-1609)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • attn_metadata (124-125)
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (2)
tensorrt_llm/_torch/speculative/drafting_loops.py (3)
  • ChainDrafter (289-476)
  • forward (300-430)
  • sample (432-469)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
  • SpecTreeManager (7-395)
tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (3)
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
  • prepare_for_generation_with_tree_decoding (110-286)
tensorrt_llm/_torch/speculative/eagle3.py (2)
  • Eagle3ResourceManager (23-109)
  • Eagle3SpecMetadata (113-266)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
  • SpecTreeManager (7-395)
tensorrt_llm/_torch/attention_backend/interface.py (2)
cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h (1)
  • batch_size (167-167)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • spec_metadata (116-117)
tests/integration/defs/test_e2e.py (1)
tests/integration/defs/conftest.py (3)
  • llm_root (192-193)
  • llm_venv (702-719)
  • llm_models_root (80-94)
tensorrt_llm/_torch/speculative/drafting_loops.py (4)
tensorrt_llm/_torch/attention_backend/interface.py (10)
  • AttentionMetadata (43-347)
  • num_seqs (249-253)
  • seq_lens (171-172)
  • seq_lens (175-196)
  • seq_lens_cuda (219-220)
  • on_update (158-168)
  • num_contexts (199-200)
  • num_contexts (203-206)
  • num_tokens (271-272)
  • forward (605-628)
tensorrt_llm/_torch/speculative/eagle3.py (2)
  • Eagle3SpecMetadata (113-266)
  • forward (362-484)
tensorrt_llm/_torch/speculative/interface.py (1)
  • SpecMetadata (168-256)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
  • SpecTreeManager (7-395)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
tensorrt_llm/runtime/generation.py (1)
  • max_draft_tokens (1319-1322)
tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
tensorrt_llm/_torch/speculative/eagle3.py (2)
  • Eagle3ResourceManager (23-109)
  • Eagle3SpecMetadata (113-266)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (3)
  • ResourceManager (1203-1246)
  • get_resource_manager (1215-1216)
  • KVCacheManager (151-1154)
tensorrt_llm/llmapi/llm_args.py (1)
  • is_linear_tree (646-649)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
  • spec_metadata (116-117)
  • attn_metadata (124-125)
tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
  • ScheduledRequests (20-41)
  • batch_size (37-38)
tensorrt_llm/_torch/attention_backend/interface.py (1)
  • AttentionMetadata (43-347)
tensorrt_llm/_torch/attention_backend/trtllm.py (3)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • spec_metadata (116-117)
tensorrt_llm/_utils.py (1)
  • get_sm_version (732-734)
tensorrt_llm/_torch/speculative/interface.py (1)
  • is_eagle3 (38-39)
tensorrt_llm/_torch/speculative/eagle3.py (2)
tensorrt_llm/_torch/attention_backend/interface.py (2)
  • seq_lens (171-172)
  • seq_lens (175-196)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  • get_slot (1164-1165)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/speculative/drafter.py

67-67: Found useless expression. Either assign it to a variable or remove it.

(B018)

tensorrt_llm/_torch/speculative/spec_tree_manager.py

16-16: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


226-227: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

tensorrt_llm/_torch/attention_backend/trtllm.py

1062-1062: Found useless expression. Either assign it to a variable or remove it.

(B018)


1063-1063: Found useless expression. Either assign it to a variable or remove it.

(B018)


1064-1064: Found useless expression. Either assign it to a variable or remove it.

(B018)


1118-1118: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)
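
The B018 and B011 findings above can be sketched as follows. This is a minimal illustration only: `select_kernel_name` and its threshold are hypothetical and do not reproduce the code at lines 1062-1118 of trtllm.py.

```python
def select_kernel_name(sm_version: int) -> str:
    """Hypothetical dispatcher, used only to illustrate the two Ruff findings."""
    # B018: a bare expression on its own line (for example just `sm_version`)
    # has no effect. Either delete it or bind the value to a name that is used.
    is_supported = sm_version >= 80  # assumed threshold, for illustration only

    if is_supported:
        return "fast_path"

    # B011: `assert False, msg` is removed under `python -O`, so the failure
    # would be silently skipped. Raising explicitly always runs.
    raise AssertionError(f"unsupported SM version: {sm_version}")
```

Raising `AssertionError` keeps the exception type that callers may already expect while surviving optimized mode.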

tensorrt_llm/_torch/speculative/eagle3.py

186-186: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


197-197: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)


214-214: Consider [0, *accepted_path] instead of concatenation

(RUF005)
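
A hedged sketch of how the B905 and RUF005 findings could be addressed. The guidelines above target Python 3.8+, while `zip(..., strict=True)` exists only from Python 3.10, so the helper below performs the length check by hand. `zip_equal` and `path_with_root` are hypothetical names, not functions from eagle3.py or spec_tree_manager.py.

```python
from typing import Iterator, List, Sequence, Tuple


def zip_equal(left: Sequence[int], right: Sequence[int]) -> Iterator[Tuple[int, int]]:
    """Pair up two sequences, refusing inputs of different lengths."""
    # On Python 3.10+ this check is what `zip(left, right, strict=True)` does.
    if len(left) != len(right):
        raise ValueError(f"length mismatch: {len(left)} != {len(right)}")
    return zip(left, right)


def path_with_root(accepted_path: List[int]) -> List[int]:
    """Prepend the root node index (0) to an accepted path."""
    # RUF005: unpacking replaces the concatenation `[0] + accepted_path`.
    return [0, *accepted_path]
```

Making the length check explicit turns a silent truncation into an immediate error, which is the behavior B905 asks the author to decide on.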

🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/model_drafter.py (2)

576-581: Using max_total_draft_tokens for static tree is correct

Looping to max_total_draft_tokens aligns with tree semantics and buffer sizes.


170-173: Confirmed: field is properly initialized on all paths

Verification shows py_num_accepted_draft_tokens_indices is initialized in LlmRequest.__init__() at line 485 as self.py_num_accepted_draft_tokens_indices = []. This initialization applies to all instance creation paths:

  1. Direct instantiation: LlmRequest(request_id=...) goes through __init__
  2. Child copy: LlmRequest(llm_request=child) also goes through __init__

Since _create_draft_request() creates new requests via the constructor, all instances get the field initialized. The assignment at line 172 safely copies from the source request, which is guaranteed to have the field initialized. No AttributeError risk exists.

@tensorrt-cicd
Collaborator

PR_Github #22532 [ run ] completed with state SUCCESS. Commit: ef8b2b6
/LLM/main/L0_MergeRequest_PR pipeline #16987 completed with status: 'FAILURE'

@yweng0828 yweng0828 requested review from a team as code owners October 27, 2025 09:24
@yweng0828 yweng0828 force-pushed the yweng/add_draft_token_tree_runtime_on_cdl branch from b17a837 to 655723b Compare October 27, 2025 09:27
@yweng0828 yweng0828 requested review from sunnyqgg and ziyixiong-nv and removed request for brb-nv and pcastonguay October 27, 2025 09:28
@yweng0828
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #22626 [ run ] triggered by Bot. Commit: 655723b

@tensorrt-cicd
Collaborator

PR_Github #22626 [ run ] completed with state SUCCESS. Commit: 655723b
/LLM/main/L0_MergeRequest_PR pipeline #17056 completed with status: 'FAILURE'

@yweng0828
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22749 [ run ] triggered by Bot. Commit: 60c0caf

@yweng0828 yweng0828 requested a review from ixlmar October 28, 2025 08:52
@tensorrt-cicd
Collaborator

PR_Github #22749 [ run ] completed with state SUCCESS. Commit: 60c0caf
/LLM/main/L0_MergeRequest_PR pipeline #17154 completed with status: 'FAILURE'

) # exclude the root node
return num_accepted_draft_tokens - 1

def _tree_sampling_batch(
Collaborator Author

Hi @ixlmar, sorry, I know you've made a lot of improvements to batched sampling.
However, I decided to remove this function because we currently implement the draft token tree only in capturable drafting loops (CDLs), which may have better performance. The corresponding tree sampling now appears only in the sample() function in drafting_loops.py (this approach is somewhat like a one-model design).

We have also implemented the draft token tree for the non-CDL path (draft PR), and that version of the drafter requires calling _tree_sampling_batch() after each forward pass. However, I currently have no plans to merge it.

I could also keep this function for future use, but I'm not sure if it would introduce additional maintenance burden. So I'd like to hear your thoughts.

cc @mikeiovine

model_outputs = {
"logits": logits,
}
# Create the chain drafter
Collaborator Author

Because we moved tree_sampling to sample() in drafting_loops.py, we need to modify these tests accordingly.

spec_dec_position_offsets: Optional[torch.Tensor] = None

# TODO: Optimized together with the subsequent dynamic tree.
# Auxiliary buffers for the static tree.
Collaborator Author

I have added several auxiliary buffers for static trees. Because the structure of a static tree is fixed during inference, these values can be computed once instead of being recalculated at every step, which also simplifies the update logic (for example, for the packed mask and the position offsets).
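
As a rough illustration of why a fixed tree allows this, the sketch below computes per-node position offsets (depths) and ancestor masks once from a parent list. The parent-list encoding, the integer bitmask layout, and the name `precompute_static_tree` are assumptions made for this example; they are not the buffers or layout used by SpecTreeManager.

```python
from typing import List, Tuple


def precompute_static_tree(parents: List[int]) -> Tuple[List[int], List[int]]:
    """Compute per-node depth and ancestor bitmask once for a fixed tree.

    parents[i] is the index of node i's parent; the root is node 0 with
    parent -1. Nodes are assumed to be ordered so a parent precedes its children.
    """
    depths: List[int] = []
    masks: List[int] = []
    for node, parent in enumerate(parents):
        if parent < 0:
            depths.append(0)
            masks.append(1 << node)  # the root attends only to itself
        else:
            depths.append(depths[parent] + 1)
            # A node attends to itself plus everything its parent attends to.
            masks.append(masks[parent] | (1 << node))
    return depths, masks


# Root 0 has children 1 and 2; node 1 has child 3.
POSITION_OFFSETS, ANCESTOR_MASKS = precompute_static_tree([-1, 0, 0, 1])
```

Since the tree never changes, both lists can be built at initialization and reused for every request and every step.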

accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
requests)
- past_key_value_lengths = attn_metadata.kv_lens_cuda
+ past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
Collaborator Author

Here we only need the first len(requests) entries; otherwise the shapes do not match and an error is raised.
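
The slicing issue can be sketched in plain Python. Lists stand in for the CUDA tensors here, and the assumption that the length buffer is preallocated for a larger maximum batch than the current request set is mine, inferred from the comment above. `rewind_kv_lengths` is a hypothetical helper, not the function touched by this PR.

```python
from typing import List


def rewind_kv_lengths(kv_lens: List[int], num_rejected: List[int]) -> List[int]:
    """Subtract each request's rejected draft tokens from its KV length."""
    if len(kv_lens) != len(num_rejected):
        raise ValueError(f"shape mismatch: {len(kv_lens)} vs {len(num_rejected)}")
    return [kv - rejected for kv, rejected in zip(kv_lens, num_rejected)]


# Buffer sized for a maximum batch of 4; only 2 requests are active.
kv_lens_buffer = [17, 23, 0, 0]
num_rejected = [2, 1]

# Passing the whole buffer would raise; slicing to the active requests works.
rewound = rewind_kv_lengths(kv_lens_buffer[:len(num_rejected)], num_rejected)
```

With real tensors the mismatch would surface as a broadcasting or kernel shape error rather than a `ValueError`, but the fix is the same slice.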

@yweng0828
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22770 [ run ] triggered by Bot. Commit: 60c0caf

@tensorrt-cicd
Collaborator

PR_Github #22770 [ run ] completed with state SUCCESS. Commit: 60c0caf
/LLM/main/L0_MergeRequest_PR pipeline #17170 completed with status: 'FAILURE'

Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
@yweng0828 yweng0828 force-pushed the yweng/add_draft_token_tree_runtime_on_cdl branch from 60c0caf to feb398c Compare October 29, 2025 02:54
@yweng0828
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #22831 [ run ] triggered by Bot. Commit: feb398c

@tensorrt-cicd
Collaborator

PR_Github #22831 [ run ] completed with state SUCCESS. Commit: feb398c
/LLM/main/L0_MergeRequest_PR pipeline #17222 completed with status: 'FAILURE'

2 participants